#! /usr/bin/env python
## This file is part of biopy.
## Copyright (C) 2010 Joseph Heled
## Author: Joseph Heled <jheled@gmail.com>
## See the files gpl.txt and lgpl.txt for copying conditions.
#

from __future__ import division

import argparse, sys, os.path
from time import time

from biopy.genericutils import fileFromName
from biopy.treeutils import toNewick, countNexusTrees, nodeHeights, getCommonAncesstor

import biopy.treesPosterior
minDistanceTree = biopy.treesPosterior.minDistanceTree

from biopy import INexus
from biopy import beastXMLhelper

class DistanceAction(argparse.Action):
  """ argparse action mapping a distance-method name to its biopy constant.

  The command line value (say 'branch-score-2') is upper-cased, its dashes are
  turned into underscores, and the attribute with that name is fetched from
  biopy.treesPosterior and stored under the option's destination.
  """
  def __call__(self, parser, names, values, option_string=None):
    constantName = values.upper().replace('-', '_')
    distanceCode = getattr(biopy.treesPosterior, constantName)
    setattr(names, self.dest, distanceCode)
    
# Command line parser. argparse only expands the %(prog)s form inside the
# description; the optparse style placeholder (percent sign followed by "prog")
# would be printed literally in the help output.
parser = argparse.ArgumentParser(description=""" %(prog)s [OPTIONS] posterior-trees.nexus 

  Generate a single summary tree for a set of posterior trees.""")

# ---- Which posterior trees to use --------------------------------------------

parser.add_argument("-b", "--burnin", type = float,
                    help="Burn-in amount (percent, default %(default)g)", default = 10)

parser.add_argument("--every", "-e", metavar="E", type = int,
                    help="""thin out - take one tree for every E. Especially
                    useful if you run out of memory (default all,
                    i.e. %(default)d)""", default = 1)

# ---- How candidate topologies are ordered and searched -----------------------

# The choices here are the criteria handled by sortTopologies().
parser.add_argument("--sort-topologies-by", dest="sortby", metavar="METHOD", 
                    choices=["highest-posterior", "conditional", "mcc", "tcb"],
                    default = "mcc",
                    help="""Try topologies in order determined by this""" +
                    """ criteria. mcc better for likelihood, tcb for truth.
                    One of tcb,mcc,conditional,highest-posterior (default
                    %(default)s)""")

# NOTE: the destination name is misspelled ("serach"); it is left unchanged
# because the rest of the script reads options.serachTopologies.
parser.add_argument("--search", dest="serachTopologies",
                    help="""Search in topology space based on zero length""" +
                    """ present in optimized trees.""",
                    action="store_true", default = False)

parser.add_argument("--method", dest="method", metavar="METHOD", 
                    choices=["min-distance", "clade-ca", "median-heights",
                             "taxon-partitions"],
                    default = "min-distance",
                    help="""method to use: min-distance,clade-ca,median-heights """ +
                    """taxon-partitions (default min-distance)""")

# DistanceAction converts the name into the matching biopy.treesPosterior
# constant. The action only runs when the option is given, so the default is
# supplied directly as a constant.
parser.add_argument("--distance-method", dest="distance", metavar="METHOD", 
                    choices=["branch-score", "branch-score-2", "heights-score",
                             "heights-only", "rooted-agreement"],
                    action = DistanceAction, default = biopy.treesPosterior.BRANCH_SCORE,
                    help="""Tree distance method to use: branch-score,branch-score-2,""" +
                    """heights-score,heights-only,rooted-agreement (default branch-score)""")

parser.add_argument("--ntops", "-n", metavar="N", type = int,
                    help="""Use the top N topologies from the posterior (default
                    %(default)d)""", default = 1)

# With a time limit every topology is a candidate, tried in the order produced
# by the --sort-topologies-by criteria.
parser.add_argument("--limit", "-l", metavar="S", type = int,
                    help="""run at most S seconds, trying out topologies in""" +
                    """ the order set by --sort-topologies-by (default -1,""" +
                    """ i.e no time limits). Note that timing code has not been""" +
                    """ tested under M$windows or OSX """, 
                    default = -1) 

parser.add_argument("--topology",
                    help="""Try only this topology""", default = None)

parser.add_argument("--score", help="Print score as a tree attribute",
                    action="store_true", default = False)

# The flag is handed to conditionalCladeScore(), which is used only when
# sorting topologies by the 'conditional' criteria.
parser.add_argument("--laplace-correction", dest = "correction",
                    help="""Use the Laplace correction in the conditional clade""" +
                    """ score (only relevant with --sort-topologies-by conditional).""",
                    action="store_true", default = False)

parser.add_argument("--matching",
                    help="""With --topology, use only posterior trees with an """ +
                    """identical topology to the target.""",
                    action="store_true", default = False)

# ---- Gene tree / species tree compatibility ----------------------------------

parser.add_argument("--compatible", metavar="TREE",
                    help="Ensure gene tree is compatible with species TREE.",
                    default = None)

parser.add_argument("--species-mapping", dest="spmap", metavar="FILE",
                    help=("A file containing the mapping of gene lineages to" + 
                          " species for the --compatible option (currently " + 
                          "only BEAST(1) XML file supported). Otherwise, a simple" +
                          " scheme of looking for the species name inside the gene" +
                          " name is used."),
                    default = None)

# ---- Output and optimizer switches -------------------------------------------

parser.add_argument("--annotate", "-a", dest="annotate",
                    help="Add support metadata",
                    action="store_true", default = False)

parser.add_argument("-p", "--progress", dest="progress",
                    help="Print out progress messages to terminal (standard error)",
                    action="store_true", default = False)

parser.add_argument("--no-derivative", dest="derivative",
                    help="Do not use analytic derivative in optimization function",
                    action="store_false", default = True)

# Passed as the 'norm' argument of minDistanceTree().
parser.add_argument("--no-norm", dest="norm",
                    help="Do not use normalization when optimizing the min-distance tree",
                    action="store_false", default = True)

parser.add_argument('trees', metavar='FILE',  help="Trees file (NEXUS)")

# Parse the command line and open the trees file. Failing to open the file ends
# the run with exit status 1.
options = parser.parse_args()
nexusTreesFileName = options.trees
try :
  nexFile = fileFromName(nexusTreesFileName)
except Exception as e:
  # e.message is empty for exceptions created with several arguments (the
  # IOError of a missing file, for one); fall back to the exception text so
  # the user always sees a reason.
  print >> sys.stderr, "Error:", getattr(e, "message", "") or e
  sys.exit(1)

# Plain copies of the simple options.
progress = options.progress
every, ntops, limit = options.every, options.ntops, options.limit

# Only "min-distance" goes through the optimizing search.
optimizingMethod = (options.method == "min-distance")

# Burn-in is given as a percentage and kept as a fraction. A value strictly
# between zero and one most likely means a fraction was typed by mistake.
burnIn = options.burnin
if burnIn > 0 and burnIn < 1 :
  print >> sys.stderr, ("*** Warning ***: tiny value for burn-in. Burn-in is"
                        " given as a percent in the range zero to 100.")
burnIn = burnIn / 100.0

if not optimizingMethod:
  # The non-optimizing methods work with a single topology.
  ntops = 1
else :
  useOnlyMatchingTopologies = options.matching

  if options.topology is not None :
    # Replace the user supplied Newick text by the parsed tree.
    if '(' not in options.topology:
      print >> sys.stderr,"topology is not a tree? (contains no '(')"
      sys.exit(1)
    try :
      options.topology = INexus.Tree(options.topology)
    except Exception as e:
      # e.message may be empty (multi-argument exceptions), fall back to the
      # exception text.
      print >> sys.stderr, "Error in parsing tree:", getattr(e, "message", "") or e
      sys.exit(1)
  else :
    if useOnlyMatchingTopologies:
      print >> sys.stderr, """WARNING: Using --matching without --topology is\
   dubious. procceding anyway..."""

  if options.compatible is not None :
    # The species tree is either a NEXUS file (first tree is used) or a tree
    # given directly on the command line.
    if os.path.isfile(options.compatible) :
      spTree = INexus.INexus(options.compatible).trees[0]
    else :
      try :
        spTree = INexus.Tree(options.compatible)
      except Exception as e:
        print >> sys.stderr, "*** Error: failed to read tree: ",\
              getattr(e, "message", "") or e,"[",options.compatible,"]"
        sys.exit(1)

    if options.spmap is not None :
      # Species membership comes from the mapping file: readBeastFile fills
      # in the entries requested by the keys of 'bd'.
      bd = {'species' : None}
      beastXMLhelper.readBeastFile(options.spmap, bd)
      speciesMapping = bd['species']['species']

      def getSpeciesOf(name) :
        """ Name of the species whose mapping entry contains 'name', None if
        there is no such species."""
        for spName in speciesMapping :
          if name in speciesMapping[spName] :
            return spName
        return None

    else :

      spTerms = spTree.get_terminals()
      spNames = [spTree.node(x).data.taxon for x in spTerms]

      def getSpeciesOf(name) :
        """ The species whose name is a substring of 'name'. None unless
        exactly one species name matches."""
        b = [x in name for x in spNames]
        if sum(b) == 1 :
          sid = b.index(True)
          return spNames[sid]
        return None

    def _setSPS(tr, nid) :
      """ Store the set of species spanned by the subtree in each node.

      At the end, each node in the subtree of 'tr' rooted at 'nid' contains
      data.species, which holds the set of species names as strings. Returns
      the set of node 'nid'.

      Raises ValueError when a tip can not be assigned to a species.
      """
      n = tr.node(nid)

      if not len(n.succ) :
        # leaf
        s = getSpeciesOf(n.data.taxon)
        if s is None :
          # An explicit error (not an assert): this is a check of user data
          # and must stay active when python runs with -O.
          raise ValueError("no (unique) species found for gene tree taxon %s"
                           % n.data.taxon)
        allsp = frozenset([s])
      else :
        # internal node: union over the children
        allsp = frozenset().union(*[_setSPS(tr, x) for x in n.succ])

      n.data.species = allsp
      return n.data.species

    def getSPlimits(gtr, spt = spTree) :
      """ Establish a mapping between each internal node of the gene tree 'gtr' and
      its minimum height according to species tree 'spt'.

      The returned mapping starts as the node heights of 'gtr'; every internal
      node entry is then replaced by the height of the common ancestor in 'spt'
      of the species below the node, or by zero when the node spans a single
      species. Entries of the tips are left as they are.
      """

      nh = nodeHeights(gtr, allTipsZero = False)
      snh = nodeHeights(spt, allTipsZero = False)

      _setSPS(gtr, gtr.root)
      for i in gtr.all_ids() :
        n = gtr.node(i)
        if len(n.succ) :
          if len(n.data.species) > 1 :
            stx = [spt.search_taxon(x) for x in n.data.species]
            k = getCommonAncesstor(spt, stx)
            b = snh[k]
          else :
            b = 0
          nh[i] = b 
      return nh
    
# Count the trees first, so that the burn-in percentage can be turned into a
# number of leading trees to skip.
if progress:
  print >> sys.stderr, "counting trees ...,",
nTrees = countNexusTrees(nexusTreesFileName)

# establish trees

nexusReader = INexus.INexus(simple=True)

nBurninTrees = int(burnIn*nTrees)

if progress:
  print >> sys.stderr, "reading %d trees ...," % ((nTrees - nBurninTrees)//every),

try :
  # Skip the burn-in and keep one tree out of 'every'.
  trees = list(nexusReader.read(nexFile, slice(nBurninTrees, -1, every)))
except Exception as e:
  # e.message may be empty (multi-argument exceptions such as IOError), fall
  # back to the exception text.
  print >> sys.stderr, "**Problem reading trees file:", getattr(e, "message", "") or e
  sys.exit(1)

if len(trees) == 0 :
  print >> sys.stderr, "**Error: No trees read."
  sys.exit(1)

# Node attributes read from the file are dropped from every tree: it saves
# memory when the trees carry attributes, for a small speed penalty.
for tree in trees :
  for nodeData in [tree.node(nodeId).data for nodeId in tree.all_ids()] :
    if hasattr(nodeData, "attributes") :
      del nodeData.attributes


from biopy.treeMeasure import cladesInTreesSet, conditionalCladeScore
from biopy.treeutils import getTreeClades, treeHeight
from math import log

import copy, itertools, random
def getNNIcandidates(tr) :
  """ Alternative topologies obtained from 'tr' by NNI moves around its zero
  length branches.

  The nodes considered are all non-root nodes whose branch length is zero.
  When there are none the result is an empty list. Otherwise the first tree
  returned is a copy of 'tr' with a move applied at every such internal node,
  followed by one more tree per such internal node. Tips are skipped. The
  input tree itself is never modified.

  NOTE(review): every later tree is copied from the previously generated one,
  not from 'tr', so the moves accumulate from tree to tree - confirm this is
  intended.
  """

  def swapWithSibling(tree, nodeId) :
    """ Exchange (in place) a randomly picked child of 'nodeId' with the
    sibling of 'nodeId'. Assumes two children per node."""
    node = tree.node(nodeId)
    side = random.choice((0,1))
    child = node.succ[side]
    parentId = node.prev
    # the picked child moves up, below the parent ...
    tree.node(child).prev = parentId
    parent = tree.node(parentId)
    siblingSlot = 1 - parent.succ.index(nodeId)
    # ... while the sibling moves down and takes its place
    node.succ[side] = parent.succ[siblingSlot]
    tree.node(node.succ[side]).prev = nodeId
    parent.succ[siblingSlot] = child

  zeroBranchNodes = [nodeId for nodeId in tr.all_ids()
                     if nodeId != tr.root and tr.node(nodeId).data.branchlength == 0]
  if not zeroBranchNodes :
    return []

  # First candidate: all the moves applied to one copy.
  tr = copy.deepcopy(tr)
  for nodeId in zeroBranchNodes :
    if tr.node(nodeId).succ :
      swapWithSibling(tr, nodeId)
  found = [tr]

  # One additional candidate per internal node, each a modified copy of the
  # one before it.
  for nodeId in zeroBranchNodes :
    if not tr.node(nodeId).succ :
      continue
    tr = copy.deepcopy(tr)
    swapWithSibling(tr, nodeId)
    found.append(tr)

  return found

def sortTopologies(allTops, atMost) :
  """ Order the topologies in 'allTops' by the criteria in options.sortby.

  allTops is a list of (topology, trees) pairs, where trees is the non-empty
  list of posterior trees sharing that topology. The list is sorted in place,
  best first, and returned.

  The per-clade values and the conditional score come from biopy.treeMeasure
  (not shown here); the key for each criteria is

    tcb                sum over the clades of the topology of the total of the
                       per-clade values collected with the branch length.
    mcc                sum over the clades of the topology of the log of the
                       per-clade value.
    conditional        conditionalCladeScore() of the topology.
    highest-posterior  number of trees with the topology. In addition, the
                       trees of each topology are sorted by the distance of
                       their height from the median tree height; this stops
                       once the topologies handled hold at least 'atMost'
                       trees.

  'atMost' is used only by highest-posterior. For all keys but
  highest-posterior the first tree of a topology stands for the topology.

  Raises ValueError for an unknown criteria.
  """
  # Flatten in one pass. Adding the lists one by one re-copies the growing
  # result for each topology, which is quadratic when there are many.
  trees = list(itertools.chain.from_iterable(trs for top,trs in allTops))

  # Keys take the (topology, trees) pair as a single argument and index into
  # it, instead of relying on tuple parameters in a lambda.
  if options.sortby == 'tcb' :
    cl = cladesInTreesSet(trees, func = lambda n : n.branchlength)
    for k in cl :
      cl[k] = sum(cl[k])
    tcb = lambda tt : sum([cl[frozenset(c)]
                           for c,node in getTreeClades(tt[1][0], False)])
    allTops.sort(key = tcb, reverse=1)
  elif options.sortby == 'mcc' :
    cl = cladesInTreesSet(trees)
    mcc = lambda tt : sum([log(cl[frozenset(c)])
                           for c,node in getTreeClades(tt[1][0], False)])
    allTops.sort(key = mcc, reverse=1)
  elif options.sortby == "highest-posterior":
    # Sort by amount of support
    allTops.sort(reverse=1, key = lambda tt : len(tt[1]))

    # sort each group with same support by distance of root to median root
    allh = sorted([treeHeight(t) for t in trees])
    hmed = allh[len(allh)//2]
    tot = 0
    for top,trs in allTops :
      # not efficient
      trs.sort(key = lambda h : abs(treeHeight(h)-hmed))
      tot += len(trs)
      if tot >= atMost :
        break

  elif options.sortby == "conditional":
    clc2 = cladesInTreesSet(trees, 1)
    lar = lambda tt : conditionalCladeScore(tt[1][0], clc2, options.correction)
    allTops.sort(key = lar, reverse=1)
  else :
    raise ValueError(options.sortby)

  return allTops

if options.method != "taxon-partitions" :
  if progress:
    print >> sys.stderr, "collect topologies and sort ...,",

  # Group the posterior trees by topology (Newick text with topology only).
  topology = dict()
  for tree in trees :
    topologyKey = toNewick(tree, None, topologyOnly=True)
    topology.setdefault(topologyKey, []).append(tree)

  # (topology, trees) pairs in the order they should be tried. With a time
  # limit all the trees count, otherwise only the requested top ones.
  allt = topology.items()
  nToOrder = len(allt) if limit > 0 else ntops
  allt = sortTopologies(allt, nToOrder)

  # Can not use more topologies than there are.
  ntops = min(ntops, len(allt))

if not optimizingMethod :
  from biopy import treesSummaries
  if progress:
    print >> sys.stderr, " computing ...,",
    
  # "median-heights", "taxon-partitions"
  if options.method == "clade-ca" :
    stree = treesSummaries.summaryTreeUsingCA(allt[0][1][0], trees)
  elif options.method == "median-heights" :
    stree = treesSummaries.summaryTreeUsingMedianHeights(allt[0][1][0], trees)
  elif options.method == "taxon-partitions" :
    stree = treesSummaries.taxaPartitionsSummaryTree(trees)
  else :
    raise ValueError( options.method )
#  print toNewick(stree)
#  sys.exit(0)
else :  
  if options.topology is not None :
    candidates = [options.topology]
    limit = -1
  elif limit <= 0 :
    if progress:
      pPost = sum([len(x[1]) for x in allt[:ntops]]) / len(trees)
      print >> sys.stderr, """using top %d topologies out of %d, covering %.1f%% of\
   posterior topologies...,""" % \
            (min(ntops, len(allt)), len(allt), 100*pPost),

    k = ntops-1
    lLast = len(allt[k][1])
    while k < len(allt) and lLast == len(allt[k][1]) :
      k += 1
    if k > ntops :
       print >> sys.stderr
       print >> sys.stderr, """*** WARNING ***:  %d additional topologies have the \
  same support as the %dth one (%d trees, %.3f%%)""" % (k - ntops, ntops, lLast, lLast/len(trees))

    candidates = [x[1][0] for x in allt[:ntops]]
  else :
    if 0 :
      import random

      candidates = []
      while len(candidates) < len(allt) :
        l = []
        k = len(candidates)
        lFirst = len(allt[k][1])
        while k < len(allt) and lFirst == len(allt[k][1]) :
          l.append(allt[k][1][0])
          k += 1
        random.shuffle(l)
        # print len(candidates), len(l), lFirst
        candidates.extend(l)

    candidates = [x[1][0] for x in allt]
    #for x in candidates:
    #  print toNewick(x, None, topologyOnly=1)
    print >> sys.stderr, """trying in order %d topologies (time permitting) ...""" \
          % len(candidates),

  bestTree, bestScore = None, float('infinity')

  if progress:
    print >> sys.stderr, "searching ...,",

  nCandidatesTried = 0

  if limit > 0 or progress :
    startTime = time()

  postTrees = trees
  toposTried = dict()
  lesserCans = []

  while len(candidates) :
    tree = candidates.pop(0)
    treeTop = toNewick(tree, topologyOnly=1)
    if treeTop in toposTried:
      continue

    if useOnlyMatchingTopologies:
      k = treeTop
      if k not in topology :
        # no such trees, skip topology
        print >> sys.stderr, "*** skip %s, no posterior trees" % k
        continue
      postTrees = topology[k]
      print >> sys.stderr, "Using %d posterior trees (out of %d) for %s" \
            % (len(postTrees), len(trees), k) 

    nh = getSPlimits(tree) if options.compatible is not None else None

    # Use normalization and init defaults 
    norm = bool(options.norm) # True

    tr, score = minDistanceTree(options.distance, tree, postTrees,
                                nodesMinHeight = nh,
                                withDerivative = options.derivative,
                                norm = norm)
    if tr is None :
      # failed: incompatible
      continue

    if options.distance == biopy.treesPosterior.BRANCH_SCORE_2 :
      # normalize score
      score = (score / len(postTrees)) ** 0.5

    #print 
    #print score,toNewick(tr)

    newBest = False
    if score < bestScore:
      bestScore = score
      bestTree = tr
      newBest = True

      if options.serachTopologies :
        cn = getNNIcandidates(tr)
        if len(cn) :
          # Try the major topology (with all zero-branch NNI applied) next
          candidates = cn[0:1] + candidates
          lesserCans = cn[1:] + lesserCans

    if options.serachTopologies :
      toposTried[treeTop] = True
      # Some hackish criteria: start injecting lesser topologies after 10 main
      # ones between 4 main ones

      if lesserCans and len(toposTried) > 10 and (len(toposTried)) % 4 == 0 :
        candidates.insert(0,lesserCans.pop(0))

    nCandidatesTried += 1
    if limit > 0 or progress :
      timenow = time()
      if progress:
        print >> sys.stderr, \
              ("%ds/%d (%g/%g%s), " % (round(timenow - startTime),
                                  nCandidatesTried, score, bestScore, '*' if newBest else '')),
      if timenow - startTime >= limit > 0 :
        print >> sys.stderr, "time limit reached ...,",
        break
  if limit > 0:
    print >> sys.stderr, "examined %d topologies in %.1f seconds," \
          % (nCandidatesTried, time() - startTime),


  if progress or limit > 0 and (not options.annotate):
    print >>  sys.stderr, "done." 

  if options.score:
    print '[&W %g]' % bestScore,

  stree = bestTree
  
if stree:
  if options.annotate :
    if progress :
       print >>  sys.stderr, "annotating ...," 
    from biopy import treesSummaries
    treesSummaries.annotateTree(stree, trees)

if progress :
  print >>  sys.stderr, "done."
    
print toNewick(stree, attributes="attributes" if options.annotate else None)
